{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A cute little demo showing the simplest usage of minGPT. Configured to run fine on a MacBook Air in about a minute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "from mingpt.utils import set_seed\n",
    "set_seed(3407)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import hashlib\n",
    "import pickle\n",
    "\n",
    "class SortDataset(Dataset):\n",
    "    \"\"\" \n",
    "    Dataset for the Sort problem. E.g. for problem length 6:\n",
    "    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n",
    "    Which will feed into the transformer concatenated as:\n",
    "    input:  0 0 2 1 0 1 0 0 0 1 1\n",
    "    output: I I I I I 0 0 0 1 1 2\n",
    "    where I is \"ignore\", as the transformer is reading the input sequence\n",
    "\n",
    "    Examples are generated randomly on every access. Each distinct input\n",
    "    sequence is assigned to 'train' or 'test' by a digest of its contents,\n",
    "    so a given input sequence always lands in the same split.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, split, length=6, num_digits=3):\n",
    "        \"\"\"\n",
    "        split:      'train' or 'test', the side of the split to sample from\n",
    "        length:     number of digits in each sequence to be sorted\n",
    "        num_digits: digits are drawn from range(num_digits); also the vocab size\n",
    "        \"\"\"\n",
    "        assert split in {'train', 'test'}\n",
    "        self.split = split\n",
    "        self.length = length\n",
    "        self.num_digits = num_digits\n",
    "    \n",
    "    def __len__(self):\n",
    "        # examples are sampled on the fly in __getitem__, so this is a fixed nominal size\n",
    "        return 10000 # ...\n",
    "    \n",
    "    def get_vocab_size(self):\n",
    "        return self.num_digits\n",
    "    \n",
    "    def get_block_size(self):\n",
    "        # the length of the sequence that will feed into transformer, \n",
    "        # containing concatenated input and the output, but -1 because\n",
    "        # the transformer starts making predictions at the last input element\n",
    "        return self.length * 2 - 1\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Sample one example from this split; idx is ignored.\n",
    "\n",
    "        Returns (x, y), two long tensors of length 2*length - 1: x is the\n",
    "        concatenation [input, sorted input] without its last element, y is the\n",
    "        same concatenation without its first element, with the first length-1\n",
    "        positions set to -1.\n",
    "        \"\"\"\n",
    "        \n",
    "        # use rejection sampling to generate an input example from the desired split\n",
    "        while True:\n",
    "            # generate some random integers\n",
    "            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n",
    "            # half of the time let's try to boost the number of examples that \n",
    "            # have a large number of repeats, as this is what the model seems to struggle\n",
    "            # with later in training, and they are kind of rare\n",
    "            # (note: with the defaults length=6, num_digits=3 there are at most 3 unique\n",
    "            # digits, so the check below never rejects anything)\n",
    "            if torch.rand(1).item() < 0.5:\n",
    "                if inp.unique().nelement() > self.length // 2:\n",
    "                    # too many unique digits, re-sample\n",
    "                    continue\n",
    "            # figure out if this generated example is train or test based on its hash.\n",
    "            # A real digest is used instead of the builtin hash(): hash() of bytes is\n",
    "            # salted per interpreter process, so the split would change between kernel\n",
    "            # restarts and could disagree between separately started worker processes.\n",
    "            digest = hashlib.sha256(pickle.dumps(inp.tolist())).digest()\n",
    "            h = int.from_bytes(digest[:8], 'big')\n",
    "            inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n",
    "            if inp_split == self.split:\n",
    "                break # ok\n",
    "        \n",
    "        # solve the task: i.e. sort\n",
    "        sol = torch.sort(inp)[0]\n",
    "\n",
    "        # concatenate the problem specification and the solution\n",
    "        cat = torch.cat((inp, sol), dim=0)\n",
    "\n",
    "        # the inputs to the transformer will be the offset sequence\n",
    "        x = cat[:-1].clone()\n",
    "        y = cat[1:].clone()\n",
    "        # we only want to predict at output locations, mask out the loss at the input locations\n",
    "        y[:self.length-1] = -1\n",
    "        return x, y\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1 -1\n",
      "0 -1\n",
      "1 -1\n",
      "0 -1\n",
      "0 -1\n",
      "0 0\n",
      "0 0\n",
      "0 0\n",
      "0 0\n",
      "0 1\n",
      "1 1\n"
     ]
    }
   ],
   "source": [
    "# build both splits, then show one training example as (input token, target) pairs;\n",
    "# a target of -1 marks a position whose loss is masked out\n",
    "train_dataset, test_dataset = SortDataset('train'), SortDataset('test')\n",
    "x, y = train_dataset[0]\n",
    "for token, target in zip(x.tolist(), y.tolist()):\n",
    "    print(token, target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of parameters: 0.09M\n"
     ]
    }
   ],
   "source": [
    "# create a GPT instance\n",
    "from mingpt.model import GPT\n",
    "\n",
    "# start from the library's default config and override only what this demo needs\n",
    "model_config = GPT.get_default_config()\n",
    "# 'gpt-nano' is a named preset inside minGPT; the output below reports 0.09M parameters for it\n",
    "model_config.model_type = 'gpt-nano'\n",
    "# vocabulary size and context length are taken from the dataset so the model matches the task\n",
    "model_config.vocab_size = train_dataset.get_vocab_size()\n",
    "model_config.block_size = train_dataset.get_block_size()\n",
    "model = GPT(model_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running on device cuda\n"
     ]
    }
   ],
   "source": [
    "# create a Trainer object\n",
    "from mingpt.trainer import Trainer\n",
    "\n",
    "# start from the Trainer defaults and override a few knobs for this small problem\n",
    "train_config = Trainer.get_default_config()\n",
    "train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n",
    "# presumably the number of training iterations to run (the log below ends at iter 1900) - confirm in mingpt.trainer\n",
    "train_config.max_iters = 2000\n",
    "# presumably the number of data-loading worker processes, 0 meaning load in the main process - confirm in mingpt.trainer\n",
    "train_config.num_workers = 0\n",
    "trainer = Trainer(train_config, model, train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter_dt 0.00ms; iter 0: train loss 1.06407\n",
      "iter_dt 18.17ms; iter 100: train loss 0.14712\n",
      "iter_dt 18.70ms; iter 200: train loss 0.05315\n",
      "iter_dt 19.65ms; iter 300: train loss 0.04404\n",
      "iter_dt 31.64ms; iter 400: train loss 0.04724\n",
      "iter_dt 18.43ms; iter 500: train loss 0.02521\n",
      "iter_dt 19.83ms; iter 600: train loss 0.03352\n",
      "iter_dt 19.58ms; iter 700: train loss 0.00539\n",
      "iter_dt 18.72ms; iter 800: train loss 0.02057\n",
      "iter_dt 18.26ms; iter 900: train loss 0.00360\n",
      "iter_dt 18.50ms; iter 1000: train loss 0.00788\n",
      "iter_dt 20.64ms; iter 1100: train loss 0.01162\n",
      "iter_dt 18.63ms; iter 1200: train loss 0.00963\n",
      "iter_dt 18.32ms; iter 1300: train loss 0.02066\n",
      "iter_dt 18.40ms; iter 1400: train loss 0.01739\n",
      "iter_dt 18.37ms; iter 1500: train loss 0.00376\n",
      "iter_dt 18.67ms; iter 1600: train loss 0.00133\n",
      "iter_dt 18.38ms; iter 1700: train loss 0.00179\n",
      "iter_dt 18.66ms; iter 1800: train loss 0.00079\n",
      "iter_dt 18.48ms; iter 1900: train loss 0.00042\n"
     ]
    }
   ],
   "source": [
    "def batch_end_callback(trainer):\n",
    "    \"\"\"Print the iteration time and training loss on every 100th iteration.\"\"\"\n",
    "    if trainer.iter_num % 100 != 0:\n",
    "        return\n",
    "    print(f\"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n",
    "\n",
    "# hook the logger into the trainer, then train\n",
    "trainer.set_callback('on_batch_end', batch_end_callback)\n",
    "trainer.run()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# now let's perform some evaluation\n",
    "# (presumably this switches the torch module to evaluation mode - confirm against mingpt.model.GPT;\n",
    "# the trailing semicolon keeps the notebook from echoing the call's return value)\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train final score: 5000/5000 = 100.00% correct\n",
      "test final score: 5000/5000 = 100.00% correct\n"
     ]
    }
   ],
   "source": [
    "def eval_split(trainer, split, max_batches):\n",
    "    \"\"\"\n",
    "    Score the notebook-level model on one split by greedy-decoding the sorted\n",
    "    half of each sequence and comparing it with the ground truth.\n",
    "\n",
    "    trainer:     only trainer.device is used, to place each batch\n",
    "    split:       'train' or 'test', selects train_dataset or test_dataset\n",
    "    max_batches: stop after this many batches of 100 examples; None runs the whole loader\n",
    "\n",
    "    Prints up to 3 mistakes and a summary line. Returns the number of fully\n",
    "    correct sequences as a 0-dim float tensor.\n",
    "    \"\"\"\n",
    "    dataset = {'train':train_dataset, 'test':test_dataset}[split]\n",
    "    # take the sequence length from the split being evaluated, not always from train_dataset\n",
    "    n = dataset.length # naughty direct access shrug\n",
    "    results = []\n",
    "    mistakes_printed_already = 0\n",
    "    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)\n",
    "    for b, (x, y) in enumerate(loader):\n",
    "        x = x.to(trainer.device)\n",
    "        y = y.to(trainer.device)\n",
    "        # isolate the input pattern alone\n",
    "        inp = x[:, :n]\n",
    "        sol = y[:, -n:]\n",
    "        # let the model sample the rest of the sequence\n",
    "        cat = model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n",
    "        sol_candidate = cat[:, n:] # isolate the filled in sequence\n",
    "        # compare the predicted sequence to the true sequence\n",
    "        correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n",
    "        for i in range(x.size(0)):\n",
    "            results.append(int(correct[i]))\n",
    "            if not correct[i] and mistakes_printed_already < 3: # only print up to 3 mistakes to get a sense\n",
    "                mistakes_printed_already += 1\n",
    "                print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n",
    "        if max_batches is not None and b+1 >= max_batches:\n",
    "            break\n",
    "    rt = torch.tensor(results, dtype=torch.float)\n",
    "    print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(results), 100*rt.mean()))\n",
    "    return rt.sum()\n",
    "\n",
    "# run a lot of examples from both train and test through the model and verify the output correctness\n",
    "with torch.no_grad():\n",
    "    train_score = eval_split(trainer, 'train', max_batches=50)\n",
    "    test_score  = eval_split(trainer, 'test',  max_batches=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input sequence  : [[0, 0, 2, 1, 0, 1]]\n",
      "predicted sorted: [[0, 0, 0, 1, 1, 2]]\n",
      "gt sort         : [0, 0, 0, 1, 1, 2]\n",
      "matches         : True\n"
     ]
    }
   ],
   "source": [
    "# sanity check: push one hand-written sequence through the model as well\n",
    "n = train_dataset.length # reaching into the dataset directly for the sequence length\n",
    "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n",
    "assert inp.size(1) == n\n",
    "with torch.no_grad():\n",
    "    # greedy decoding (no sampling) of n more tokens after the prompt\n",
    "    cat = model.generate(inp, n, do_sample=False)\n",
    "sol_candidate = cat[:, n:] # keep only the generated half\n",
    "sol = inp[0].sort().values # ground truth from torch's own sort\n",
    "for label, value in (\n",
    "    ('input sequence  :', inp.tolist()),\n",
    "    ('predicted sorted:', sol_candidate.tolist()),\n",
    "    ('gt sort         :', sol.tolist()),\n",
    "    ('matches         :', bool((sol == sol_candidate).all())),\n",
    "):\n",
    "    print(label, value)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.4 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
